"""
Phase 8: J-Curve Robustness
=============================
Four tests to buttress the demographic-capital J-curve finding:

1. Pre-trends (h=-3,-2,-1): outcomes should NOT trend before inflow shock
2. Placebo instrument: non-demographic gravity predictions should NOT produce J-curve
3. J-curve × institutions: h=2 positive should concentrate in high rule_of_law
4. Dose-response: larger demographic inflows → more pronounced J-curve

All focused on OECD subsample where instruments are strong (F=34.38).
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
import warnings
warnings.filterwarnings('ignore')

PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Input and output locations, resolved relative to the project directory.
DATA = PROJECT_DIR / "data" / "processed"
OUT_TABLES = PROJECT_DIR / "output" / "tables"
# Import-time side effect: make sure the output directory exists.
OUT_TABLES.mkdir(parents=True, exist_ok=True)

# Sibling directory that holds the bilateral gravity results and panel
# read by construct_placebo_instrument().
GRAVITY_DIR = ROOT_DIR / "gravity_bilateral"

# Local-projection horizons: h = 0..MAX_HORIZON after the shock, plus the
# negative horizons used for the pre-trend test.
MAX_HORIZON = 5
PRE_HORIZONS = [-3, -2, -1]
# Control columns added to every regression in this script.
CONTROLS = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'log_rel_opw', 'kaopen']

# ISO3 codes used to restrict the panel to the OECD subsample.
OECD = {
    'AUS', 'AUT', 'BEL', 'CAN', 'CHE', 'CHL', 'COL', 'CRI', 'CZE', 'DEU',
    'DNK', 'ESP', 'EST', 'FIN', 'FRA', 'GBR', 'GRC', 'HUN', 'IRL', 'ISL',
    'ISR', 'ITA', 'JPN', 'KOR', 'LTU', 'LUX', 'LVA', 'MEX', 'NLD', 'NOR',
    'NZL', 'POL', 'PRT', 'SVK', 'SVN', 'SWE', 'TUR', 'USA',
}

# Key outcomes for J-curve tests
# Maps panel column -> (display label, variable type). The type is passed to
# build_horizon_outcome(): 'growth' outcomes are cumulated over the horizon,
# 'level' outcomes are simply led or lagged.
OUTCOMES = {
    'gross_fixed_investment_gdp': ('Investment/GDP', 'level'),
    'rgdp_growth': ('GDP Growth', 'growth'),
    'mpk_proxy': ('MPK', 'level'),
    'delta_log_kl': ('Δlog K/L', 'growth'),
}


def stars(p):
    """Return the significance marker for p-value *p*.

    '***' for p < 0.01, '**' for p < 0.05, '*' for p < 0.1, otherwise ''.
    A NaN p-value fails every comparison and therefore yields ''.
    """
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p < cutoff:
            return marker
    return ''


def build_horizon_outcome(df, y_var, h, var_type):
    """Return a sorted copy of *df* with the horizon-*h* outcome column added.

    The input frame is not modified. The copy is sorted by ('iso3', 'year')
    and gains one column named ``f'{y_var}_h{h}'``:

    * h == 0: the outcome itself.
    * h > 0, var_type == 'growth': within-country sum of y over t..t+h.
    * h > 0, otherwise: y led by h rows within country.
    * h < 0, var_type == 'growth': within-country sum of y over the |h| rows
      before t (t+h..t-1).
    * h < 0, otherwise: y lagged by |h| rows within country.

    NOTE(review): leads and lags are taken by row position within each
    country, so a gap in `year` would misalign horizons -- confirm the panel
    has consecutive years per country.
    """
    out = df.sort_values(['iso3', 'year']).copy()

    if h == 0:
        out[f'{y_var}_h0'] = out[y_var]
        return out

    target = f'{y_var}_h{h}'
    by_country = out.groupby('iso3')[y_var]
    cumulative = var_type == 'growth'

    if h > 0:
        if cumulative:
            # Rolling sum ends at each row; pulling it back h rows makes the
            # window start at t and end at t+h.
            span = h + 1
            out[target] = by_country.transform(
                lambda s: s.rolling(window=span, min_periods=span).sum().shift(-h)
            )
        else:
            out[target] = by_country.shift(-h)
        return out

    # Negative horizon: look backwards from t.
    lag = abs(h)
    if cumulative:
        # Rolling sum over |h| rows, pushed forward one row so the window
        # ends at t-1 and excludes the shock period itself.
        out[target] = by_country.transform(
            lambda s: s.rolling(window=lag, min_periods=lag).sum().shift(1)
        )
    else:
        out[target] = by_country.shift(lag)
    return out


def run_lp(df, y_col, x_vars, controls, key_var):
    """Run a single local-projection regression and report one coefficient.

    Parameters
    ----------
    df : DataFrame holding the outcome, regressors, 'iso3' and 'year'.
    y_col : name of the dependent-variable column.
    x_vars, controls : lists of regressor column names; names missing from
        *df* are skipped rather than raising.
    key_var : regressor whose estimate is returned.

    Returns
    -------
    dict with 'coef', 'se', 'p', 'ci_lo', 'ci_hi' (coef -/+ 1.96 * se),
    'n_obs' and 'r_squared' read from the fitted PanelGLS object, or None
    when the regression cannot be run: the outcome or *key_var* column is
    absent, or fewer than 30 complete rows remain after dropping NaNs.
    """
    wanted = [y_col] + x_vars + controls + ['iso3', 'year']
    sub = df[[c for c in wanted if c in df.columns]].dropna()
    regressors = [v for v in x_vars + controls if v in sub.columns]

    # Decide whether the estimate is obtainable *before* fitting, so that a
    # missing outcome/key regressor neither raises nor wastes a model fit.
    if y_col not in sub.columns or key_var not in regressors:
        return None
    if len(sub) < 30:
        return None

    gls = PanelGLS()
    gls.fit(sub[y_col].values, sub[regressors].values,
            sub['iso3'].values, sub['year'].values)

    idx = regressors.index(key_var)
    coef = gls.beta[idx]
    se = gls.se[idx]

    return {
        'coef': coef,
        'se': se,
        'p': gls.pvalues[idx],
        'ci_lo': coef - 1.96 * se,
        'ci_hi': coef + 1.96 * se,
        'n_obs': gls.n_obs,
        'r_squared': gls.r_squared,
    }


def construct_placebo_instrument():
    """
    Build gravity-predicted inflows using ONLY non-demographic coefficients.

    Reads the model "2c" coefficients from gravity_results.csv, keeps just the
    pure gravity terms (distance, contiguity, language, colonial ties, GDP
    product) and applies them to the bilateral panel, so the dZ terms, kaopen
    and their interactions never enter the prediction. The linear prediction
    is exponentiated, summed over all rows sharing a recipient (`iso_d`) and
    year, and logged again (floored at 1e-6 before the log).

    Returns a DataFrame with columns 'iso3', 'year' and
    'log_predicted_placebo_inflows'.

    A warning is printed when a gravity term has no coefficient or no column
    in the bilateral panel; without any usable term the prediction is
    exp(0) = 1 per row and the "instrument" is only a log row count.
    """
    grav = pd.read_csv(GRAVITY_DIR / "output" / "tables" / "gravity_results.csv")
    model_2c = grav[grav['model'] == '2c: Gravity + Demographics + KAOPEN interactions']

    # Only these terms are allowed into the placebo prediction.
    gravity_only_vars = ['log_dist', 'contiguity', 'common_lang_official',
                         'colonial_ties', 'log_gdp_product']

    kept = model_2c[model_2c['variable'].isin(gravity_only_vars)]
    # Row order is preserved; a repeated variable keeps its last coefficient.
    gravity_coeffs = dict(zip(kept['variable'], kept['coefficient']))

    no_coef = [v for v in gravity_only_vars if v not in gravity_coeffs]
    if no_coef:
        print(f"WARNING: no model 2c coefficient found for {no_coef}; "
              f"these terms are left out of the placebo prediction")

    bp = pd.read_csv(GRAVITY_DIR / "data" / "processed" / "bilateral_panel.csv")

    # Keep rows where every available gravity variable is observed.
    valid = bp.dropna(subset=[v for v in gravity_only_vars if v in bp.columns]).copy()

    no_column = [v for v in gravity_coeffs if v not in valid.columns]
    if no_column:
        print(f"WARNING: bilateral panel has no column for {no_column}; "
              f"these terms are left out of the placebo prediction")

    # Compute placebo fitted value (gravity-only, no demographics)
    valid['predicted_placebo'] = 0.0
    for var, coef in gravity_coeffs.items():
        if var in valid.columns:
            valid['predicted_placebo'] += coef * valid[var]

    valid['predicted_placebo_level'] = np.exp(valid['predicted_placebo'])

    # Aggregate by recipient-year
    placebo = valid.groupby(['iso_d', 'year']).agg(
        predicted_placebo_inflows=('predicted_placebo_level', 'sum'),
    ).reset_index().rename(columns={'iso_d': 'iso3'})

    # Floor before the log so a zero aggregate cannot produce -inf.
    placebo['log_predicted_placebo_inflows'] = np.log(
        placebo['predicted_placebo_inflows'].clip(lower=1e-6))

    print(f"Placebo instrument: {len(placebo)} obs, "
          f"{placebo['iso3'].nunique()} countries")
    return placebo[['iso3', 'year', 'log_predicted_placebo_inflows']]


def make_ascii_irf(irf_data, title, width=60):
    """Create ASCII impulse response plot.

    Parameters
    ----------
    irf_data : dict mapping horizon -> result dict (keys 'coef', 'se', 'p',
        'ci_lo', 'ci_hi', 'n_obs') or None/empty when no estimate exists.
    title : heading printed above the table.
    width : number of character cells in the plot area.

    Returns
    -------
    A multi-line string: one table row per horizon with the estimate and a
    bar showing the confidence interval ('-'), the point estimate ('*', 'o'
    or '·' by significance) and the zero line ('|'). When no horizon has a
    usable estimate the result is just the title followed by '(no data)'.

    The zero bar is drawn on top of the interval dashes so it stays visible
    in rows whose interval covers zero; only the point-estimate marker may
    replace it.
    """
    def usable(pt):
        # A point is plottable when it exists and carries a non-NaN coef.
        return bool(pt) and not np.isnan(pt.get('coef', np.nan))

    ordered = sorted(irf_data.items())
    lines = [title, '=' * len(title), '']

    all_vals = [pt[key] for _, pt in ordered if usable(pt)
                for key in ('ci_lo', 'ci_hi')]
    if not all_vals:
        return title + '\n  (no data)\n'

    # Always include zero in the plotted range so the zero bar has a cell.
    vmin = min(min(all_vals), 0)
    vmax = max(max(all_vals), 0)
    span = vmax - vmin if vmax > vmin else 1

    def pos(val):
        return int((val - vmin) / span * (width - 1))

    zero_pos = pos(0)

    lines.append(f'  h  {"coef":>12s} {"SE":>8s} {"p":>6s}  {"N":>5s}  Plot')
    lines.append(f'  -  {"----":>12s} {"--":>8s} {"---":>6s}  {"---":>5s}  ' + '-' * width)

    for h, pt in ordered:
        if not usable(pt):
            lines.append(f' {h:2d}  {"n/a":>12s}')
            continue

        coef_str = f"{pt['coef']:.5f}{stars(pt['p'])}"

        # Clamp every position into the plot area.
        c_pos = min(max(0, pos(pt['coef'])), width - 1)
        lo_pos = max(0, pos(pt['ci_lo']))
        hi_pos = min(width - 1, pos(pt['ci_hi']))

        # Layering order: interval, then zero bar, then point estimate.
        plot = [' '] * width
        for i in range(lo_pos, hi_pos + 1):
            plot[i] = '-'
        plot[zero_pos] = '|'
        plot[c_pos] = '*' if pt['p'] < 0.05 else 'o' if pt['p'] < 0.1 else '·'

        bar = ''.join(plot)
        lines.append(f' {h:2d}  {coef_str:>12s} {pt["se"]:8.5f} {pt["p"]:6.4f}  {pt["n_obs"]:5d}  {bar}')

    lines.append(f'     {"":>12s} {"":>8s} {"":>6s}  {"":>5s}  ' + '-' * width)
    lines.append(f'     {vmin:>10.4f}{" " * (width - 20)}{vmax:>10.4f}')
    lines.append('  (* p<0.05, o p<0.1, · p>0.1, | = zero)')
    lines.append('')
    return '\n'.join(lines)


def test_1_pretrends(df):
    """Test 1: Pre-trends at h=-3,-2,-1. Should be null.

    For every outcome in OUTCOMES present in *df*, regresses the horizon-h
    outcome on log_predicted_demo_inflows plus CONTROLS for each pre-trend
    horizon and each h in 0..MAX_HORIZON, printing an ASCII response plot.

    Returns {outcome column: {horizon: run_lp result or None}}.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("TEST 1: PRE-TRENDS (h = -3, -2, -1)")
    print("Should show NO significant coefficients if J-curve is causal")
    print(rule)

    shock = 'log_predicted_demo_inflows'
    horizons = PRE_HORIZONS + list(range(MAX_HORIZON + 1))
    results = {}

    for y_var, (y_label, var_type) in OUTCOMES.items():
        if y_var not in df.columns:
            continue

        print(f"\n--- {y_label} ---")
        irf = {}

        for h in horizons:
            y_col = f'{y_var}_h{h}'
            df_h = build_horizon_outcome(df, y_var, h, var_type)
            if y_col in df_h.columns:
                irf[h] = run_lp(df_h, y_col, [shock], CONTROLS, shock)

        print(make_ascii_irf(irf, f'Pre-trends + IRF: Demo Inflows → {y_label} (OECD)'))
        results[y_var] = irf

    return results


def test_2_placebo(df, placebo):
    """Test 2: Placebo instrument (non-demographic gravity). Should NOT show J-curve.

    Left-merges *placebo* onto *df* by ('iso3', 'year'), then for each outcome
    and each h in 0..MAX_HORIZON runs the same regression twice: once with
    log_predicted_demo_inflows and once with log_predicted_placebo_inflows as
    the key regressor. Both response plots are printed.

    Returns {outcome column: {'demo': {h: result}, 'placebo': {h: result}}}.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("TEST 2: PLACEBO INSTRUMENT (gravity-only, no demographics)")
    print("Should show WEAKER or NULL J-curve pattern")
    print(rule)

    df_p = df.merge(placebo, on=['iso3', 'year'], how='left')
    n_plac = df_p['log_predicted_placebo_inflows'].notna().sum()
    print(f"Placebo instrument available for {n_plac} obs")

    # Order matters only for readability: demographic first, placebo second.
    shocks = (
        ('demo', 'log_predicted_demo_inflows'),
        ('placebo', 'log_predicted_placebo_inflows'),
    )
    results = {}

    for y_var, (y_label, var_type) in OUTCOMES.items():
        if y_var not in df_p.columns:
            continue

        print(f"\n--- {y_label} ---")
        irfs = {spec: {} for spec, _ in shocks}

        for h in range(MAX_HORIZON + 1):
            df_h = build_horizon_outcome(df_p, y_var, h, var_type)
            y_col = f'{y_var}_h{h}'
            for spec, shock in shocks:
                irfs[spec][h] = run_lp(df_h, y_col, [shock], CONTROLS, shock)

        print(make_ascii_irf(irfs['demo'], f'DEMOGRAPHIC: Demo Inflows → {y_label}'))
        print(make_ascii_irf(irfs['placebo'], f'PLACEBO: Gravity-only → {y_label}'))
        results[y_var] = {'demo': irfs['demo'], 'placebo': irfs['placebo']}

    return results


def test_3_institutions(df):
    """Test 3: J-curve × institutions. h=2 positive should be in high rule_of_law.

    Adds the interaction log_predicted_demo_inflows * rule_of_law to a copy of
    *df* and, for each outcome and each h in 0..MAX_HORIZON, regresses the
    horizon-h outcome on the inflow term, rule_of_law, the interaction and
    CONTROLS. A table row is printed per horizon and an ASCII plot of the
    interaction coefficient per outcome.

    Returns {outcome column: {horizon: result}} where each result describes
    the interaction coefficient ('coef', 'se', 'p', 'ci_lo', 'ci_hi', 'n_obs',
    'r_squared') plus 'demo_coef' and 'rol_coef'. Horizons with fewer than 30
    complete rows are stored as None, matching what run_lp() returns for the
    other tests, so they show up as "n/a" in the plot. Returns {} when *df*
    has no 'rule_of_law' column.
    """
    print("\n" + "=" * 70)
    print("TEST 3: J-CURVE × INSTITUTIONS")
    print("Interaction: demo_inflows × rule_of_law at each horizon")
    print("=" * 70)

    if 'rule_of_law' not in df.columns:
        print("WARNING: rule_of_law not available")
        return {}

    # Work on a copy so the caller's frame does not gain the interaction column.
    df = df.copy()
    df['demo_x_rol'] = df['log_predicted_demo_inflows'] * df['rule_of_law']

    results = {}

    for y_var, (y_label, var_type) in OUTCOMES.items():
        if y_var not in df.columns:
            continue

        print(f"\n--- {y_label} ---")
        print(f"  {'h':>3s}  {'demo β':>10s}  {'RoL β':>10s}  {'interact β':>12s}  {'int. p':>8s}  {'N':>5s}")
        print(f"  {'─'*3}  {'─'*10}  {'─'*10}  {'─'*12}  {'─'*8}  {'─'*5}")

        irf_interact = {}

        for h in range(MAX_HORIZON + 1):
            df_h = build_horizon_outcome(df, y_var, h, var_type)
            y_col = f'{y_var}_h{h}'

            all_vars = [y_col, 'log_predicted_demo_inflows', 'rule_of_law',
                        'demo_x_rol'] + CONTROLS + ['iso3', 'year']
            sub = df_h[[c for c in all_vars if c in df_h.columns]].dropna()
            actual_x = ['log_predicted_demo_inflows', 'rule_of_law', 'demo_x_rol']
            actual_x += [c for c in CONTROLS if c in sub.columns]

            if len(sub) < 30:
                print(f"  {h:3d}  insufficient obs")
                # Record the gap explicitly instead of dropping the horizon.
                irf_interact[h] = None
                continue

            gls = PanelGLS()
            gls.fit(sub[y_col].values, sub[actual_x].values,
                    sub['iso3'].values, sub['year'].values)

            idx_demo = actual_x.index('log_predicted_demo_inflows')
            idx_rol = actual_x.index('rule_of_law')
            idx_int = actual_x.index('demo_x_rol')

            sig = stars(gls.pvalues[idx_int])
            print(f"  {h:3d}  {gls.beta[idx_demo]:10.5f}  {gls.beta[idx_rol]:10.5f}  "
                  f"{gls.beta[idx_int]:12.5f}{sig:3s}  {gls.pvalues[idx_int]:8.4f}  {gls.n_obs:5d}")

            irf_interact[h] = {
                'coef': gls.beta[idx_int],
                'se': gls.se[idx_int],
                'p': gls.pvalues[idx_int],
                'ci_lo': gls.beta[idx_int] - 1.96 * gls.se[idx_int],
                'ci_hi': gls.beta[idx_int] + 1.96 * gls.se[idx_int],
                'n_obs': gls.n_obs,
                'r_squared': gls.r_squared,
                'demo_coef': gls.beta[idx_demo],
                'rol_coef': gls.beta[idx_rol],
            }

        plot = make_ascii_irf(irf_interact,
                              f'Interaction (demo × RoL) → {y_label}')
        print(plot)
        results[y_var] = irf_interact

    return results


def test_4_dose_response(df):
    """Test 4: Split at median demographic inflow. High dose → stronger J-curve.

    Countries are ranked by their mean log_predicted_demo_inflows; those at or
    above the cross-country median form the high-dose group, the rest the
    low-dose group (countries whose mean is NaN fall in neither). For each
    outcome and each h in 0..MAX_HORIZON the baseline regression is run
    separately on both groups, with a comparison table and two ASCII plots
    printed per outcome.

    Returns {outcome column: {'high': {h: result}, 'low': {h: result}}}.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("TEST 4: DOSE-RESPONSE")
    print("Split OECD at median predicted_demo_inflows")
    print(rule)

    shock = 'log_predicted_demo_inflows'

    # Median of country-level mean inflows defines the split point.
    avg_by_country = df.groupby('iso3')[shock].mean().dropna()
    median_inflow = avg_by_country.median()
    high_dose = set(avg_by_country[avg_by_country >= median_inflow].index)
    low_dose = set(avg_by_country[avg_by_country < median_inflow].index)

    print(f"Median log_predicted_demo_inflows: {median_inflow:.3f}")
    print(f"High-dose countries ({len(high_dose)}): {sorted(high_dose)}")
    print(f"Low-dose countries ({len(low_dose)}): {sorted(low_dose)}")

    df_high = df[df['iso3'].isin(high_dose)].copy()
    df_low = df[df['iso3'].isin(low_dose)].copy()

    def cells(res):
        """Format one regression result as (beta, p, N) table cells."""
        if not res:
            return 'n/a', '', ''
        return (f"{res['coef']:.5f}{stars(res['p'])}",
                f"{res['p']:.4f}",
                f"{res['n_obs']}")

    results = {}

    for y_var, (y_label, var_type) in OUTCOMES.items():
        if y_var not in df.columns:
            continue

        print(f"\n--- {y_label} ---")
        print(f"  {'h':>3s}  {'HIGH β':>12s}  {'HIGH p':>8s}  {'LOW β':>12s}  {'LOW p':>8s}  {'HIGH N':>6s}  {'LOW N':>6s}")
        print(f"  {'─'*3}  {'─'*12}  {'─'*8}  {'─'*12}  {'─'*8}  {'─'*6}  {'─'*6}")

        irf_high = {}
        irf_low = {}

        for h in range(MAX_HORIZON + 1):
            y_col = f'{y_var}_h{h}'
            high_frame = build_horizon_outcome(df_high, y_var, h, var_type)
            low_frame = build_horizon_outcome(df_low, y_var, h, var_type)

            irf_high[h] = run_lp(high_frame, y_col, [shock], CONTROLS, shock)
            irf_low[h] = run_lp(low_frame, y_col, [shock], CONTROLS, shock)

            hb, hp, hn = cells(irf_high[h])
            lb, lp, ln = cells(irf_low[h])
            print(f"  {h:3d}  {hb:>12s}  {hp:>8s}  {lb:>12s}  {lp:>8s}  {hn:>6s}  {ln:>6s}")

        print(make_ascii_irf(irf_high, f'HIGH DOSE: → {y_label}'))
        print(make_ascii_irf(irf_low, f'LOW DOSE: → {y_label}'))
        results[y_var] = {'high': irf_high, 'low': irf_low}

    return results


def save_results(pretrend_res, placebo_res, instit_res, dose_res):
    """Save all results to CSV and markdown.

    Parameters
    ----------
    pretrend_res, instit_res : {outcome: {horizon: result dict or None}}.
    placebo_res, dose_res : {outcome: {spec: {horizon: result dict or None}}}.

    Writes `jcurve_robustness_results.csv` (one row per available estimate,
    tagged with test / outcome / horizon / spec) and `jcurve_robustness.md`
    (one table per test) into OUT_TABLES. Missing estimates are skipped in
    the CSV and left as blank cells in the markdown tables.

    The markdown file is written as UTF-8 explicitly: it contains 'β' and
    '×', which the platform default encoding cannot always represent.
    """
    rows = []

    def collect(test_name, spec, y_var, irf):
        """Append one CSV row per horizon of *irf* that has an estimate."""
        rows.extend(
            {'test': test_name, 'outcome': y_var, 'horizon': h,
             'spec': spec, **pt}
            for h, pt in irf.items() if pt
        )

    # Pre-trends
    for y_var, irf in pretrend_res.items():
        collect('pre-trends', 'demographic', y_var, irf)

    # Placebo
    for y_var, by_spec in placebo_res.items():
        for spec, irf in by_spec.items():
            collect('placebo', spec, y_var, irf)

    # Institutions
    for y_var, irf in instit_res.items():
        collect('institutions', 'interaction', y_var, irf)

    # Dose-response
    for y_var, by_spec in dose_res.items():
        for spec, irf in by_spec.items():
            collect('dose-response', spec, y_var, irf)

    df_out = pd.DataFrame(rows)
    df_out.to_csv(OUT_TABLES / "jcurve_robustness_results.csv", index=False)

    def coef_cell(pt):
        """Coefficient with significance stars, or '' when missing."""
        return f"{pt['coef']:.4f}{stars(pt['p'])}" if pt else ''

    def p_cell(pt):
        """P-value to four decimals, or '' when missing."""
        return f"{pt['p']:.4f}" if pt else ''

    # Markdown summary
    with open(OUT_TABLES / "jcurve_robustness.md", 'w', encoding='utf-8') as f:
        f.write("# Table 8: J-Curve Robustness Tests (OECD Subsample)\n\n")

        # Test 1: Pre-trends
        f.write("## Test 1: Pre-Trends\n\n")
        f.write("Coefficients at h<0 should be insignificant if J-curve is causal.\n\n")
        f.write("| Outcome | h=-3 | h=-2 | h=-1 | h=0 | h=1 | h=2 | h=3 |\n")
        f.write("|---------|------|------|------|-----|-----|-----|-----|\n")
        for y_var, (y_label, _) in OUTCOMES.items():
            irf = pretrend_res.get(y_var, {})
            # Columns are fixed to match the header row above.
            cells = [coef_cell(irf.get(h)) for h in [-3, -2, -1, 0, 1, 2, 3]]
            f.write(f"| {y_label} | {' | '.join(cells)} |\n")
        f.write("\n")

        # Test 2: Placebo
        f.write("## Test 2: Placebo Instrument (Gravity-Only vs Demographic)\n\n")
        f.write("| Outcome | h | Demo β | Demo p | Placebo β | Placebo p |\n")
        f.write("|---------|---|--------|--------|-----------|----------|\n")
        for y_var, (y_label, _) in OUTCOMES.items():
            data = placebo_res.get(y_var, {})
            for h in range(MAX_HORIZON + 1):
                d = data.get('demo', {}).get(h)
                p = data.get('placebo', {}).get(h)
                f.write(f"| {y_label} | {h} | {coef_cell(d)} | {p_cell(d)} "
                        f"| {coef_cell(p)} | {p_cell(p)} |\n")
        f.write("\n")

        # Test 3: Institutions
        f.write("## Test 3: J-Curve × Rule of Law Interaction\n\n")
        f.write("| Outcome | h | Interaction β | p | Base β | RoL β |\n")
        f.write("|---------|---|--------------|---|--------|-------|\n")
        for y_var, (y_label, _) in OUTCOMES.items():
            irf = instit_res.get(y_var, {})
            for h in range(MAX_HORIZON + 1):
                pt = irf.get(h)
                if pt:
                    f.write(f"| {y_label} | {h} | {pt['coef']:.4f}{stars(pt['p'])} "
                            f"| {pt['p']:.4f} | {pt.get('demo_coef', 0):.4f} "
                            f"| {pt.get('rol_coef', 0):.4f} |\n")
        f.write("\n")

        # Test 4: Dose-response
        f.write("## Test 4: Dose-Response (High vs Low Demographic Inflows)\n\n")
        f.write("| Outcome | h | High β | High p | Low β | Low p |\n")
        f.write("|---------|---|--------|--------|-------|-------|\n")
        for y_var, (y_label, _) in OUTCOMES.items():
            data = dose_res.get(y_var, {})
            for h in range(MAX_HORIZON + 1):
                hi = data.get('high', {}).get(h)
                lo = data.get('low', {}).get(h)
                f.write(f"| {y_label} | {h} | {coef_cell(hi)} | {p_cell(hi)} "
                        f"| {coef_cell(lo)} | {p_cell(lo)} |\n")
        f.write("\n")

        f.write("*p<0.1, **p<0.05, ***p<0.01. OECD subsample. PanelGLS with AR(1).*\n")

    print(f"\nSaved: {OUT_TABLES / 'jcurve_robustness_results.csv'}")
    print(f"Saved: {OUT_TABLES / 'jcurve_robustness.md'}")


def main():
    """Run the four J-curve robustness tests on the OECD subsample.

    Loads `deepening_panel.csv` from DATA, keeps rows whose iso3 is in OECD,
    runs tests 1-4 in order, saves their results, and prints a short summary:

    * pre-trends: number of h<0 coefficients with p < 0.1 (PASS if at most 2;
      reported as NO ESTIMATES when no pre-trend regression could be run, so
      an empty result is not mistaken for a pass);
    * placebo: outcomes significant at h=2 under each instrument;
    * institutions: outcomes with a positive, significant interaction at h=2;
    * dose-response: outcomes whose |coef| at h=2 is larger in the high-dose
      group than in the low-dose group.
    """
    print("=" * 70)
    print("PHASE 8: J-CURVE ROBUSTNESS TESTS")
    print("=" * 70)

    df = pd.read_csv(DATA / "deepening_panel.csv")
    df_oecd = df[df['iso3'].isin(OECD)].copy()
    print(f"OECD panel: {len(df_oecd)} obs, {df_oecd['iso3'].nunique()} countries")

    # Test 1: Pre-trends
    pretrend_res = test_1_pretrends(df_oecd)

    # Test 2: Placebo instrument
    print("\nConstructing placebo instrument...")
    placebo = construct_placebo_instrument()
    placebo_res = test_2_placebo(df_oecd, placebo)

    # Test 3: J-curve × institutions
    instit_res = test_3_institutions(df_oecd)

    # Test 4: Dose-response
    dose_res = test_4_dose_response(df_oecd)

    # Save
    print("\n--- Saving results ---")
    save_results(pretrend_res, placebo_res, instit_res, dose_res)

    # Summary assessment
    print("\n" + "=" * 70)
    print("SUMMARY ASSESSMENT")
    print("=" * 70)

    # Pre-trends: count significant coefficients at h<0
    pre_points = [irf.get(h) for irf in pretrend_res.values() for h in PRE_HORIZONS]
    pre_points = [pt for pt in pre_points if pt]
    pre_total = len(pre_points)
    pre_sig = sum(1 for pt in pre_points if pt['p'] < 0.1)
    if pre_total == 0:
        # Nothing was estimated, so there is no evidence either way.
        pre_verdict = 'NO ESTIMATES'
    else:
        pre_verdict = 'PASS ✓' if pre_sig <= 2 else 'CONCERN'
    print(f"\nPre-trends: {pre_sig}/{pre_total} significant at 10% "
          f"({pre_verdict})")

    # Placebo: compare demo vs placebo significance pattern
    demo_sig_h2 = 0
    plac_sig_h2 = 0
    for data in placebo_res.values():
        d = data.get('demo', {}).get(2)
        p = data.get('placebo', {}).get(2)
        if d and d['p'] < 0.1:
            demo_sig_h2 += 1
        if p and p['p'] < 0.1:
            plac_sig_h2 += 1
    print(f"Placebo at h=2: {demo_sig_h2} demo sig vs {plac_sig_h2} placebo sig "
          f"({'PASS ✓' if demo_sig_h2 > plac_sig_h2 else 'MIXED'})")

    # Institutions: positive interaction at h=2?
    inst_h2_pos = 0
    for irf in instit_res.values():
        pt = irf.get(2)
        if pt and pt['coef'] > 0 and pt['p'] < 0.1:
            inst_h2_pos += 1
    print(f"Institutions × demo at h=2: {inst_h2_pos} outcomes with positive sig interaction")

    # Dose-response: high > low at h=2?
    dose_mono = 0
    for data in dose_res.values():
        hi = data.get('high', {}).get(2)
        lo = data.get('low', {}).get(2)
        if hi and lo and abs(hi['coef']) > abs(lo['coef']):
            dose_mono += 1
    print(f"Dose-response at h=2: {dose_mono}/{len(dose_res)} outcomes show monotonicity")

    print("\nPhase 8 complete.")


# Run the full robustness suite only when this file is executed as a script.
if __name__ == '__main__':
    main()
